package crypto

import (
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"math/rand"
	"time"
)

// Fixed inputs that InitCrypto mixes with the caller-supplied salt to derive
// the package key and nonce table. Changing either value changes every derived
// key/nonce, so data encrypted before the change can no longer be decrypted.
const (
	sign   = "20200629170609"
	answer = "It's right!"
)

// Package-wide crypto state. Both are nil until InitCrypto succeeds; every
// Encrypt*/Decrypt* function in this file indexes into nonces, so they panic
// if called before InitCrypto.
var (
	key    []byte   // set by InitCrypto to []byte(MD5s(sign, salt, answer)); used by the string/object helpers
	nonces [][]byte // set by InitCrypto: 256 entries, entry n = hex-decoded MD5f("%s:AES-Nonce%s:%s-%d", sign, salt, answer, n)[:24]
)

// InitCrypto derives the package-wide AES key and the 256-entry nonce table
// from salt. It must be called before any Encrypt*/Decrypt* function in this
// package, all of which index into nonces.
//
// Derivation:
//   - key = []byte(MD5s(sign, salt, answer))
//   - nonces[n] = hex.DecodeString(MD5f("%s:AES-Nonce%s:%s-%d", sign, salt, answer, n)[:24])
//     for n in [0, 256). Decoding 24 hex characters yields a 12-byte nonce.
//     This assumes MD5f returns a hex digest of at least 24 characters; verify
//     against its definition.
//
// Calling InitCrypto again replaces the previous key and nonce table rather
// than appending to it. If an error occurs, the previous state is left
// untouched.
func InitCrypto(salt string) error {
	// Seeds math/rand's global source, which EncryptObjectRandomNonce uses to
	// pick a nonce index. (rand.Seed is deprecated since Go 1.20, where the
	// global source is seeded automatically; kept for older toolchains.)
	rand.Seed(time.Now().UnixNano())

	newKey := []byte(MD5s(sign, salt, answer))
	newNonces := make([][]byte, 0, 256)
	for n := 0; n < 256; n++ {
		// The format string is part of the derivation. Changing it, even to add
		// the seemingly missing ':' after "AES-Nonce", breaks existing ciphertexts.
		nonce, err := hex.DecodeString(MD5f("%s:AES-Nonce%s:%s-%d", sign, salt, answer, n)[:24])
		if err != nil {
			return fmt.Errorf("deriving nonce %d: %w", n, err)
		}
		newNonces = append(newNonces, nonce)
	}

	// Publish both together so key and nonces always come from the same salt.
	key = newKey
	nonces = newNonces
	return nil
}

// EncryptInt64 encrypts val under a 32-byte key built from the key argument
// (its 8 big-endian bytes repeated four times, see putInt64) and the fixed
// nonce nonces[0]. It returns the ciphertext as padded URL-safe base64.
//
// The key parameter shadows the package-level key inside this function; the
// package key is not used here.
//
// NOTE(review): the effective key material is only 64 bits, and the same
// nonce is reused for every call. If AESEncryptInt64 uses a stream/AEAD mode
// (CTR/GCM), that reuse is a security hazard. Confirm the mode.
func EncryptInt64(val int64, key int64) ([]byte, error) {
	aesKey := make([]byte, 32)
	putInt64(aesKey, key)

	sealed, err := AESEncryptInt64(val, aesKey, nonces[0])
	if err != nil {
		return nil, err
	}

	out := make([]byte, base64.URLEncoding.EncodedLen(len(sealed)))
	base64.URLEncoding.Encode(out, sealed)
	return out, nil
}

// DecryptInt64 reverses EncryptInt64. It decodes data from padded URL-safe
// base64, rebuilds the 32-byte key from the key argument via putInt64, and
// decrypts with the fixed nonce nonces[0].
//
// A base64 failure is returned wrapped with a "base64 decode error" prefix.
// Errors from AESDecryptInt64 are returned as-is.
func DecryptInt64(data []byte, key int64) (int64, error) {
	enc := base64.URLEncoding
	raw := make([]byte, enc.DecodedLen(len(data)))
	n, err := enc.Decode(raw, data)
	if err != nil {
		return 0, fmt.Errorf("base64 decode error: %w", err)
	}

	aesKey := make([]byte, 32)
	putInt64(aesKey, key)
	// DecodedLen is an upper bound; only the first n bytes are real data.
	return AESDecryptInt64(raw[:n], aesKey, nonces[0])
}

// EncryptString encrypts obj with the package key and the fixed nonce
// nonces[0], and returns the ciphertext as a padded URL-safe base64 string.
//
// NOTE(review): equal inputs produce equal outputs, because the nonce never
// changes. Confirm this determinism is intended and that AESEncryptString's
// mode tolerates nonce reuse.
func EncryptString(obj string) (string, error) {
	sealed, err := AESEncryptString(obj, key, nonces[0])
	if err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(sealed)
	return encoded, nil
}

// DecryptString reverses EncryptString. It decodes data from padded URL-safe
// base64, then decrypts with the package key and nonces[0].
//
// A base64 failure is returned wrapped with a "base64 decode error" prefix.
// Errors from AESDecryptString are returned as-is.
func DecryptString(data string) (string, error) {
	raw, decodeErr := base64.URLEncoding.DecodeString(data)
	if decodeErr != nil {
		return "", fmt.Errorf("base64 decode error: %w", decodeErr)
	}
	return AESDecryptString(raw, key, nonces[0])
}

// EncryptObjectRandomNonce encrypts obj with the package key and a nonce
// chosen from nonces by a random index in [0, 256). The output is padded
// URL-safe base64 of (index byte || ciphertext), which lets
// DecryptObjectRandomNonce find the same nonce again.
//
// NOTE(review): the index comes from math/rand, which is not
// cryptographically secure, and only 256 nonces exist per key. By the
// birthday bound, repeats are likely after about 20 messages. If
// AESEncyptObject uses CTR/GCM, nonce reuse is a security hazard. Confirm.
func EncryptObjectRandomNonce(obj interface{}) ([]byte, error) {
	idx := byte(rand.Intn(256))

	sealed, err := AESEncyptObject(obj, key, nonces[idx])
	if err != nil {
		return nil, err
	}

	// Frame as [idx][ciphertext...].
	framed := make([]byte, 0, 1+len(sealed))
	framed = append(framed, idx)
	framed = append(framed, sealed...)

	out := make([]byte, base64.URLEncoding.EncodedLen(len(framed)))
	base64.URLEncoding.Encode(out, framed)
	return out, nil
}

// DecryptObjectRandomNonce reverses EncryptObjectRandomNonce. It decodes data
// from padded URL-safe base64 and reads the first byte as the index into
// nonces. It then decrypts the remaining bytes into obj with the package key
// and that nonce.
//
// It returns an error, instead of panicking, when the decoded payload is
// empty and has no nonce index byte. A base64 failure is returned wrapped with
// a "base64 decode error" prefix. Errors from AESDecryptObject are returned
// as-is.
func DecryptObjectRandomNonce(data []byte, obj interface{}) error {
	buf := make([]byte, base64.URLEncoding.DecodedLen(len(data)))
	n, err := base64.URLEncoding.Decode(buf, data)
	if err != nil {
		return fmt.Errorf("base64 decode error: %w", err)
	}
	// Empty input decodes to n == 0. Without this guard, buf[1:n] / buf[0]
	// would panic with an index out of range.
	if n < 1 {
		return fmt.Errorf("ciphertext too short: missing nonce index byte")
	}
	return AESDecryptObject(buf[1:n], key, nonces[buf[0]], obj)
}

// EncryptObject encrypts obj via AESEncyptObject with the package key and the
// fixed nonce nonces[0], and returns the ciphertext as padded URL-safe base64.
// Unlike EncryptObjectRandomNonce, no nonce index byte is prepended, so the
// output must be decrypted with DecryptObject.
//
// NOTE(review): the nonce is reused for every call. Confirm that
// AESEncyptObject's mode tolerates this.
func EncryptObject(obj interface{}) ([]byte, error) {
	sealed, err := AESEncyptObject(obj, key, nonces[0])
	if err != nil {
		return nil, err
	}
	enc := base64.URLEncoding
	out := make([]byte, enc.EncodedLen(len(sealed)))
	enc.Encode(out, sealed)
	return out, nil
}

// DecryptObject reverses EncryptObject. It decodes data from padded URL-safe
// base64 and decrypts the result into obj with the package key and nonces[0].
//
// A base64 failure is returned wrapped with a "base64 decode error" prefix.
// Errors from AESDecryptObject are returned as-is.
func DecryptObject(data []byte, obj interface{}) error {
	enc := base64.URLEncoding
	raw := make([]byte, enc.DecodedLen(len(data)))
	n, err := enc.Decode(raw, data)
	if err != nil {
		return fmt.Errorf("base64 decode error: %w", err)
	}
	// Only the first n bytes hold decoded data; DecodedLen is an upper bound.
	return AESDecryptObject(raw[:n], key, nonces[0], obj)
}

// putInt64 fills b[0:32] with four consecutive copies of v in big-endian byte
// order (most significant byte first). It panics before writing anything if
// len(b) < 32. Bytes at index 32 and beyond are left unchanged.
//
// NOTE(review): used to expand an int64 into a 32-byte AES key, which gives
// only 64 bits of key material.
func putInt64(b []byte, v int64) {
	_ = b[31] // one up-front bounds check; also lets the compiler drop per-write checks

	u := uint64(v)
	for chunk := 0; chunk < 32; chunk += 8 {
		for i := 0; i < 8; i++ {
			b[chunk+i] = byte(u >> uint(56-8*i))
		}
	}
}
